{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 计算下中位数\n",
    "def getDownMedian(data):\n",
    "    if data.shape[0] % 2 == 0:\n",
    "        data = np.hstack([data, np.inf])\n",
    "    return np.median(data)\n",
    "    \n",
    "def seperateNode(node, depth):\n",
    "    dataSet = node['data']\n",
    "    if dataSet.shape[0] == 0:\n",
    "        return None\n",
    "    feature = depth % dataSet.shape[1]\n",
    "    value = getDownMedian(dataSet[:, feature])\n",
    "    node['left'] = {'data':dataSet[dataSet[:,feature] < value, :],\n",
    "                    'left':None, 'right':None, 'parent':node}\n",
    "    node['right'] = {'data':dataSet[dataSet[:,feature] > value, :],\n",
    "                     'left':None, 'right':None, 'parent':node}\n",
    "    node['data'] = dataSet[dataSet[:,feature] == value, :][0]\n",
    "    node['feature'] = feature\n",
    "    seperateNode(node['left'], depth+1)\n",
    "    seperateNode(node['right'], depth+1)\n",
    "    return node\n",
    "    \n",
    "def printNode(node, depth):\n",
    "    if node == None or node['data'].shape[0]==0:\n",
    "        return\n",
    "    print (\" \"*depth, node['data'])\n",
    "    printNode(node['left'], depth+1)\n",
    "    printNode(node['right'], depth+1)\n",
    "    \n",
    "    \n",
    "def createTree(dataSet):\n",
    "    root = {'data':dataSet, 'left':None, 'right':None}\n",
    "    seperateNode(root, 0)\n",
    "    return root "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " [7 2]\n",
      "  [5 4]\n",
      "   [2 3]\n",
      "   [4 7]\n",
      "  [9 6]\n",
      "   [8 1]\n"
     ]
    }
   ],
   "source": [
    "root = createTree(np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]))\n",
    "printNode(root, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def isLeaf(node):\n",
    "    return True\n",
    "\n",
    "# 在kd树中找出包含目标点x的叶结点\n",
    "def findLeaf(node, depth, x):\n",
    "    # 从根结点出发，递归地向下访问kd树\n",
    "    while True:\n",
    "        feature = node['feature']\n",
    "        # 若目标点x当前维的坐标小于切分点的坐标\n",
    "        if x[feature] < node['data'][feature]:\n",
    "            # 则移动到左子节点\n",
    "            subnode = node['left']\n",
    "        # 否则移动到右子节点\n",
    "        else:\n",
    "            subnode = node['right']\n",
    "        # 直至子结点为叶结点为止\n",
    "        if subnode['data'].shape[0] == 0:\n",
    "            return node\n",
    "        node = subnode\n",
    "\n",
    "def distance(A, B):\n",
    "    return np.linalg.norm(A - B)\n",
    "\n",
    "def check(feature, value, x):\n",
    "    return np.abs(x[feature] - value[feature])\n",
    "\n",
    "def isSame(a, b):\n",
    "    return a[0] == b[0] and a[1] == b[1]\n",
    "\n",
    "def findAnother(node):\n",
    "    parent = node['parent']\n",
    "    if isSame(node['data'], parent['left']['data']):\n",
    "        return parent['right']\n",
    "    else:\n",
    "        return parent['left']\n",
    "        \n",
    "def search(root, x):\n",
    "    # 在kd树中找出包含目标点x的叶结点\n",
    "    current = findLeaf(root, 0, x)\n",
    "    # 以此结点为当前最近点\n",
    "    nearest = current['data']\n",
    "    currentDistance = distance(nearest, x)\n",
    "    # 递归地向上回退，在每个结点进行以下操作\n",
    "    # 当回退到根结点时搜索结束\n",
    "    while not isSame(current['data'], root['data']):\n",
    "        parent = current['parent']\n",
    "        dis = distance(parent['data'], x)\n",
    "        # 如果该结点保存的实例点比当前最近点距离目标点更近\n",
    "        if dis < currentDistance:\n",
    "            # 则以该实例点为“当前最近点”\n",
    "            nearest = parent['data']\n",
    "            currentDistance = dis\n",
    "        # 检查该子结点的父结点的另一个子结点对应的区域是否有更近的点\n",
    "        # 即目标结点到平面feature=value的距离\n",
    "        # 如果相交，可能在另一个子结点对应的区域内存在距目标点更近的点\n",
    "        if check(parent['feature'], parent['data'], x) < currentDistance:\n",
    "            # 递归的进行最近邻搜索\n",
    "            newnode, dis = search(findAnother(current), x)\n",
    "            if dis < currentDistance:\n",
    "                nearest = newnode\n",
    "                currentDistance = dis\n",
    "        current = parent  \n",
    "    return nearest, currentDistance\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([9, 6]), 1.4142135623730951)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "root = createTree(np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]))\n",
    "search(root, np.array([8, 7]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
